﻿// Copyright (c) Microsoft. All rights reserved.

using System.ComponentModel;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using xRetry;

namespace FunctionCalling;

/// <summary>
/// These examples demonstrate two ways that functions requested by the Gemini LLM can be invoked using the SK streaming and non-streaming AI APIs:
///
/// 1. Automatic invocation by SK (with and without nullable properties):
///    Functions requested by the LLM are invoked automatically by SK. The results of these function invocations
///    are automatically added to the chat history and returned to the LLM. The LLM reasons about the chat history
///    and generates the final response.
///    This approach is fully automated and requires no manual intervention from the caller.
///
/// 2. Manual invocation by the caller:
///    Functions requested by the LLM are returned to the AI API caller. The caller controls the invocation phase, where
///    they may decide which functions to call, when to call them, how to handle exceptions, whether to call them in parallel or sequentially, etc.
///    The caller then adds the function results or exceptions to the chat history and returns it to the LLM, which reasons about it
///    and generates the final response.
///    This approach is manual and gives the caller more control over the function invocation phase.
/// </summary>
public sealed class Gemini_FunctionCalling(ITestOutputHelper output) : BaseTest(output)
{
    /// <summary>
    /// Runs the function-calling samples against the Google AI (API key) Gemini connector.
    /// </summary>
    [RetryFact]
    public async Task GoogleAIChatCompletionWithFunctionCalling()
    {
        Console.WriteLine("============= Google AI - Gemini Chat Completion with function calling =============");

        Assert.NotNull(TestConfiguration.GoogleAI.ApiKey);
        Assert.NotNull(TestConfiguration.GoogleAI.Gemini.ModelId);

        Kernel kernel = Kernel.CreateBuilder()
            .AddGoogleAIGeminiChatCompletion(
                modelId: TestConfiguration.GoogleAI.Gemini.ModelId,
                apiKey: TestConfiguration.GoogleAI.ApiKey)
            .Build();

        await this.RunSampleAsync(kernel);
    }

    /// <summary>
    /// Runs the function-calling samples against the Vertex AI (bearer key) Gemini connector.
    /// </summary>
    [RetryFact]
    public async Task VertexAIChatCompletionWithFunctionCalling()
    {
        Console.WriteLine("============= Vertex AI - Gemini Chat Completion with function calling =============");

        Assert.NotNull(TestConfiguration.VertexAI.BearerKey);
        Assert.NotNull(TestConfiguration.VertexAI.Location);
        Assert.NotNull(TestConfiguration.VertexAI.ProjectId);
        Assert.NotNull(TestConfiguration.VertexAI.Gemini.ModelId);

        Kernel kernel = Kernel.CreateBuilder()
            .AddVertexAIGeminiChatCompletion(
                modelId: TestConfiguration.VertexAI.Gemini.ModelId,
                bearerKey: TestConfiguration.VertexAI.BearerKey,
                location: TestConfiguration.VertexAI.Location,
                projectId: TestConfiguration.VertexAI.ProjectId)
            .Build();

        // To generate a bearer key, you need the Google Cloud SDK installed, or you can use the Google web console, with the command:
        //
        //   gcloud auth print-access-token
        //
        // The code above passes the bearer key as a string. That is not recommended in production code,
        // especially if the IChatCompletionService will be long-lived, because tokens generated by the Google SDK live for 1 hour.
        // Instead, use a bearer key provider, which will be called to generate a token on demand:
        //
        // Example:
        //
        // Kernel kernel = Kernel.CreateBuilder()
        //     .AddVertexAIGeminiChatCompletion(
        //         modelId: TestConfiguration.VertexAI.Gemini.ModelId,
        //         bearerKeyProvider: () =>
        //         {
        //             // This is just an example; in production we recommend using the Google SDK to generate your BearerKey token.
        //             // This delegate will be called on every request,
        //             // so when providing the token consider a caching strategy and refresh logic for when it is expired or close to expiration.
        //             return GetBearerKey();
        //         },
        //         location: TestConfiguration.VertexAI.Location,
        //         projectId: TestConfiguration.VertexAI.ProjectId);

        await this.RunSampleAsync(kernel);
    }

    /// <summary>
    /// Auto-invokes a plugin function whose parameter type has a nullable property, using the Google AI Gemini connector.
    /// </summary>
    [RetryFact]
    public async Task GoogleAIFunctionCallingNullable()
    {
        Console.WriteLine("============= Google AI - Gemini Chat Completion with function calling (nullable properties) =============");

        Assert.NotNull(TestConfiguration.GoogleAI.ApiKey);
        Assert.NotNull(TestConfiguration.GoogleAI.Gemini.ModelId);

        // The Google AI connector must be configured with the Google AI model id, not the Vertex AI one.
        var kernelBuilder = Kernel.CreateBuilder()
            .AddGoogleAIGeminiChatCompletion(
                modelId: TestConfiguration.GoogleAI.Gemini.ModelId,
                apiKey: TestConfiguration.GoogleAI.ApiKey);

        await this.RunNullableSampleAsync(kernelBuilder);
    }

    /// <summary>
    /// Plugin whose single function takes a <see cref="WeatherRequest"/>, whose <see cref="WeatherRequest.Location"/> is nullable.
    /// </summary>
    private sealed class MyWeatherPlugin
    {
        [KernelFunction]
        [Description("Get the weather for a given location.")]
        private string GetWeather(WeatherRequest request)
        {
            return $"The weather in {request?.Location} is sunny.";
        }
    }

    /// <summary>
    /// Auto-invokes a plugin function whose parameter type has a nullable property, using the Vertex AI Gemini connector.
    /// </summary>
    [RetryFact]
    public async Task VertexAIFunctionCallingNullable()
    {
        Console.WriteLine("============= Vertex AI - Gemini Chat Completion with function calling (nullable properties) =============");

        Assert.NotNull(TestConfiguration.VertexAI.BearerKey);
        Assert.NotNull(TestConfiguration.VertexAI.Location);
        Assert.NotNull(TestConfiguration.VertexAI.ProjectId);
        Assert.NotNull(TestConfiguration.VertexAI.Gemini.ModelId);

        var kernelBuilder = Kernel.CreateBuilder()
            .AddVertexAIGeminiChatCompletion(
                modelId: TestConfiguration.VertexAI.Gemini.ModelId,
                bearerKey: TestConfiguration.VertexAI.BearerKey,
                location: TestConfiguration.VertexAI.Location,
                projectId: TestConfiguration.VertexAI.ProjectId);

        // Passing a static bearer key is fine for a sample, but tokens from `gcloud auth print-access-token` expire after 1 hour.
        // For long-lived services prefer the `bearerKeyProvider` overload; see the comment in
        // VertexAIChatCompletionWithFunctionCalling for an example.

        await this.RunNullableSampleAsync(kernelBuilder);
    }

    /// <summary>
    /// Shared body of the "nullable properties" samples: registers <see cref="MyWeatherPlugin"/>, builds the kernel and
    /// asks about the weather with automatic function choice, then prints the model's response.
    /// </summary>
    /// <param name="kernelBuilder">A builder that already has a Gemini chat completion service registered.</param>
    private async Task RunNullableSampleAsync(IKernelBuilder kernelBuilder)
    {
        kernelBuilder.Plugins.AddFromType<MyWeatherPlugin>();

        var promptExecutionSettings = new GeminiPromptExecutionSettings()
        {
            FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(),
        };

        var kernel = kernelBuilder.Build();

        var response = await kernel.InvokePromptAsync("Hi, what's the weather in New York?", new(promptExecutionSettings));

        Console.WriteLine(response.ToString());
    }

    /// <summary>
    /// Registers helper functions on <paramref name="kernel"/>, then runs the automatic (non-streaming and streaming)
    /// and manual function-calling examples.
    /// </summary>
    /// <param name="kernel">A kernel with a Gemini chat completion service registered.</param>
    private async Task RunSampleAsync(Kernel kernel)
    {
        // Add a plugin with some helper functions we want to allow the model to utilize.
        kernel.ImportPluginFromFunctions("HelperFunctions",
        [
            kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."),
            kernel.CreateFunctionFromMethod((string cityName) =>
                cityName switch
                {
                    "Boston" => "61 and rainy",
                    "London" => "55 and cloudy",
                    "Miami" => "80 and sunny",
                    "Paris" => "60 and rainy",
                    "Tokyo" => "50 and sunny",
                    "Sydney" => "75 and sunny",
                    "Tel Aviv" => "80 and sunny",
                    _ => "31 and snowing",
                }, "Get_Weather_For_City", "Gets the current weather for the specified city"),
        ]);

        Console.WriteLine("======== Example 1: Use automated function calling with a non-streaming prompt ========");
        {
            GeminiPromptExecutionSettings settings = new() { ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions };
            Console.WriteLine(await kernel.InvokePromptAsync(
                "Check current UTC time, and return current weather in Paris city", new(settings)));
            Console.WriteLine();
        }

        Console.WriteLine("======== Example 2: Use automated function calling with a streaming prompt ========");
        {
            GeminiPromptExecutionSettings settings = new() { ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions };
            await foreach (var update in kernel.InvokePromptStreamingAsync(
                               "Check current UTC time, and return current weather in Boston city", new(settings)))
            {
                Console.Write(update);
            }

            Console.WriteLine();
        }

        Console.WriteLine("======== Example 3: Use manual function calling with a non-streaming prompt ========");
        {
            var chat = kernel.GetRequiredService<IChatCompletionService>();
            var chatHistory = new ChatHistory();

            // EnableKernelFunctions advertises the kernel functions to the model but leaves invoking them to us.
            GeminiPromptExecutionSettings settings = new() { ToolCallBehavior = GeminiToolCallBehavior.EnableKernelFunctions };
            chatHistory.AddUserMessage("Check current UTC time, and return current weather in London city");
            while (true)
            {
                var result = (GeminiChatMessageContent)await chat.GetChatMessageContentAsync(chatHistory, settings, kernel);

                if (result.Content is not null)
                {
                    Console.Write(result.Content);
                }

                // No tool calls means the model has produced its final answer.
                if (result.ToolCalls is not { Count: > 0 })
                {
                    break;
                }

                // The assistant message containing the tool calls must precede the tool results in the history.
                chatHistory.Add(result);
                foreach (var toolCall in result.ToolCalls)
                {
                    if (!kernel.Plugins.TryGetFunction(toolCall.PluginName, toolCall.FunctionName, out var function))
                    {
                        // NOTE(review): no function response is added for this call, so the model never learns it failed;
                        // a production caller should report the failure back to the model instead of skipping it.
                        Console.WriteLine("Unable to find function. Please try again!");
                        continue;
                    }

                    // Copy the model-supplied arguments (if any) into KernelArguments for the invocation.
                    KernelArguments? arguments = null;
                    if (toolCall.Arguments is not null)
                    {
                        arguments = [];
                        foreach (var parameter in toolCall.Arguments)
                        {
                            arguments[parameter.Key] = parameter.Value?.ToString();
                        }
                    }

                    var functionResponse = await function.InvokeAsync(kernel, arguments);
                    Assert.NotNull(functionResponse);

                    var calledToolResult = new GeminiFunctionToolResult(toolCall, functionResponse);

                    chatHistory.Add(new GeminiChatMessageContent(calledToolResult));
                }
            }

            Console.WriteLine();
        }

        /* Uncomment this to try in a console chat loop.
        Console.WriteLine("======== Example 4: Use automated function calling with a streaming chat ========");
        {
            GeminiPromptExecutionSettings settings = new() { ToolCallBehavior = GeminiToolCallBehavior.AutoInvokeKernelFunctions };
            var chat = kernel.GetRequiredService<IChatCompletionService>();
            var chatHistory = new ChatHistory();

            while (true)
            {
                Console.Write("Question (Type \"quit\" to leave): ");
                string question = Console.ReadLine() ?? string.Empty;
                if (question == "quit")
                {
                    break;
                }

                chatHistory.AddUserMessage(question);
                System.Text.StringBuilder sb = new();
                await foreach (var update in chat.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel))
                {
                    if (update.Content is not null)
                    {
                        Console.Write(update.Content);
                        sb.Append(update.Content);
                    }
                }

                chatHistory.AddAssistantMessage(sb.ToString());
                Console.WriteLine();
            }
        }
        */
    }

    /// <summary>
    /// Argument type for <see cref="MyWeatherPlugin"/>; <see cref="Location"/> is intentionally nullable.
    /// </summary>
    private sealed class WeatherRequest
    {
        public string? Location { get; set; }
    }
}
